Skip to content

[ADT] Bitset: add shift operators, word accessors, and etc#193400

Open
JiachenYuan wants to merge 1 commit intollvm:mainfrom
JiachenYuan:perf/jiachen/bitset_prepare_for_lbm
Open

[ADT] Bitset: add shift operators, word accessors, and etc#193400
JiachenYuan wants to merge 1 commit intollvm:mainfrom
JiachenYuan:perf/jiachen/bitset_prepare_for_lbm

Conversation

@JiachenYuan
Copy link
Copy Markdown
Contributor

This PR is split out from #191757 per reviewer request. It has the following changes to llvm::Bitset<N>:

  • Added operator<</<<=/>>/>>=, getNumWords(), getWord(), and findLastSet().
  • Moved the std::array<> constructor from protected to public and explicit.

A follow-up PR will use these to re-implement LaneBitmask as a llvm::Bitset wrapper.


The unit test in the PR is largely generated by LLMs. I have reviewed it and manually applied changes to cover more edge cases.

@JiachenYuan JiachenYuan marked this pull request as ready for review April 22, 2026 03:55
@llvmbot
Copy link
Copy Markdown
Member

llvmbot commented Apr 22, 2026

@llvm/pr-subscribers-llvm-adt

Author: Jiachen Yuan (JiachenYuan)

Changes

This PR is split out from #191757 per reviewer request. It has the following changes to llvm::Bitset&lt;N&gt;:

  • Added operator&lt;&lt;/&lt;&lt;=/&gt;&gt;/&gt;&gt;=, getNumWords(), getWord(), and findLastSet().
  • Moved the std::array&lt;&gt; constructor from protected to public and explicit.

A follow-up PR will use these to re-implement LaneBitmask as a llvm::Bitset wrapper.


The unit test in the PR is largely generated by LLMs. I have reviewed it and manually applied changes to cover more edge cases.


Full diff: https://github.com/llvm/llvm-project/pull/193400.diff

2 Files Affected:

  • (modified) llvm/include/llvm/ADT/Bitset.h (+89-4)
  • (modified) llvm/unittests/ADT/BitsetTest.cpp (+198)
diff --git a/llvm/include/llvm/ADT/Bitset.h b/llvm/include/llvm/ADT/Bitset.h
index 9dc0f24b1d9f5..3cb2b7d28d83b 100644
--- a/llvm/include/llvm/ADT/Bitset.h
+++ b/llvm/include/llvm/ADT/Bitset.h
@@ -51,8 +51,9 @@ template <unsigned NumBits> class Bitset {
 
   constexpr void maskLastWord() { Bits[getLastWordIndex()] &= RemainderMask; }
 
-protected:
-  constexpr Bitset(const std::array<uint64_t, (NumBits + 63) / 64> &B) {
+public:
+  explicit constexpr Bitset(
+      const std::array<uint64_t, (NumBits + 63) / 64> &B) {
     if constexpr (sizeof(BitWord) == sizeof(uint64_t)) {
       for (size_t I = 0; I != B.size(); ++I)
         Bits[I] = B[I];
@@ -70,8 +71,6 @@ template <unsigned NumBits> class Bitset {
     }
     maskLastWord();
   }
-
-public:
   constexpr Bitset() = default;
   constexpr Bitset(std::initializer_list<unsigned> Init) {
     for (auto I : Init)
@@ -194,6 +193,92 @@ template <unsigned NumBits> class Bitset {
     }
     return false;
   }
+
+  constexpr Bitset &operator<<=(unsigned N) {
+    if (N == 0)
+      return *this;
+    if (N >= NumBits) {
+      return *this = Bitset();
+    }
+    const unsigned WordShift = N / BitwordBits;
+    const unsigned BitShift = N % BitwordBits;
+    if (BitShift == 0) {
+      for (int I = NumWords - 1; I >= static_cast<int>(WordShift); --I)
+        Bits[I] = Bits[I - WordShift];
+    } else {
+      const unsigned CarryShift = BitwordBits - BitShift;
+      for (int I = NumWords - 1; I > static_cast<int>(WordShift); --I) {
+        Bits[I] = (Bits[I - WordShift] << BitShift) |
+                  (Bits[I - WordShift - 1] >> CarryShift);
+      }
+      Bits[WordShift] = Bits[0] << BitShift;
+    }
+    for (unsigned I = 0; I < WordShift; ++I)
+      Bits[I] = 0;
+    maskLastWord();
+    return *this;
+  }
+
+  constexpr Bitset operator<<(unsigned N) const {
+    Bitset Result(*this);
+    Result <<= N;
+    return Result;
+  }
+
+  constexpr Bitset &operator>>=(unsigned N) {
+    if (N == 0)
+      return *this;
+    if (N >= NumBits) {
+      return *this = Bitset();
+    }
+    const unsigned WordShift = N / BitwordBits;
+    const unsigned BitShift = N % BitwordBits;
+    if (BitShift == 0) {
+      for (unsigned I = 0; I < NumWords - WordShift; ++I)
+        Bits[I] = Bits[I + WordShift];
+    } else {
+      const unsigned CarryShift = BitwordBits - BitShift;
+      for (unsigned I = 0; I < NumWords - WordShift - 1; ++I) {
+        Bits[I] = (Bits[I + WordShift] >> BitShift) |
+                  (Bits[I + WordShift + 1] << CarryShift);
+      }
+      Bits[NumWords - WordShift - 1] = Bits[NumWords - 1] >> BitShift;
+    }
+    for (unsigned I = NumWords - WordShift; I < NumWords; ++I)
+      Bits[I] = 0;
+    maskLastWord();
+    return *this;
+  }
+
+  constexpr Bitset operator>>(unsigned N) const {
+    Bitset Result(*this);
+    Result >>= N;
+    return Result;
+  }
+
+  /// Return the I-th 64-bit word of the bitset, from least significant to most.
+  constexpr uint64_t getWord(unsigned I) const {
+    if constexpr (BitwordBits == 64) {
+      return Bits[I];
+    } else {
+      static_assert(BitwordBits == 32, "Unsupported word size");
+      uint64_t Lo = (2 * I < NumWords) ? Bits[2 * I] : 0;
+      uint64_t Hi = (2 * I + 1 < NumWords) ? Bits[2 * I + 1] : 0;
+      return Lo | (Hi << 32);
+    }
+  }
+
+  /// Return the index of the highest set bit, or -1 if no bits are set.
+  constexpr int findLastSet() const {
+    for (int I = NumWords - 1; I >= 0; --I)
+      if (Bits[I] != 0)
+        return I * BitwordBits +
+               (BitwordBits - 1 - countl_zero_constexpr(Bits[I]));
+    return -1;
+  }
+
+  /// Return the number of 64-bit words needed to hold all bits.
+  static constexpr unsigned getNumWords() { return (NumBits + 63) / 64; }
 };
 
 } // end namespace llvm
diff --git a/llvm/unittests/ADT/BitsetTest.cpp b/llvm/unittests/ADT/BitsetTest.cpp
index 678197e31a379..ee3ef07d01979 100644
--- a/llvm/unittests/ADT/BitsetTest.cpp
+++ b/llvm/unittests/ADT/BitsetTest.cpp
@@ -294,4 +294,202 @@ TEST(BitsetTest, BitwiseOperators) {
                 TestXor128.test(127));
 }
 
+TEST(BitsetTest, ShiftOperators) {
+  // Test left shift.
+  static_assert((Bitset<64>({0}) << 10).test(10));
+  static_assert(!(Bitset<64>({0}) << 10).test(0));
+  static_assert((Bitset<64>({63}) << 1).none());
+  static_assert((Bitset<128>({0}) << 64).test(64));
+  static_assert((Bitset<128>({63}) << 1).test(64));
+  static_assert((Bitset<128>({127}) << 1).none());
+
+  // Test right shift.
+  static_assert((Bitset<64>({10}) >> 10).test(0));
+  static_assert(!(Bitset<64>({10}) >> 10).test(10));
+  static_assert((Bitset<64>({0}) >> 1).none());
+  static_assert((Bitset<128>({64}) >> 64).test(0));
+  static_assert((Bitset<128>({64}) >> 1).test(63));
+  static_assert((Bitset<128>({0}) >> 1).none());
+
+  // Test shift by 0.
+  static_assert((Bitset<64>({10, 20}) << 0) == Bitset<64>({10, 20}));
+  static_assert((Bitset<64>({10, 20}) >> 0) == Bitset<64>({10, 20}));
+
+  // Test shift by NumBits (clears all).
+  static_assert((Bitset<64>({0, 63}) << 64).none());
+  static_assert((Bitset<64>({0, 63}) >> 64).none());
+  static_assert((Bitset<128>({0, 127}) << 128).none());
+  static_assert((Bitset<128>({0, 127}) >> 128).none());
+}
+
+TEST(BitsetTest, GetNumWords64) {
+  static_assert(Bitset<1>::getNumWords() == 1);
+  static_assert(Bitset<32>::getNumWords() == 1);
+  static_assert(Bitset<64>::getNumWords() == 1);
+  static_assert(Bitset<65>::getNumWords() == 2);
+  static_assert(Bitset<96>::getNumWords() == 2);
+  static_assert(Bitset<128>::getNumWords() == 2);
+  static_assert(Bitset<129>::getNumWords() == 3);
+}
+
+TEST(BitsetTest, GetWord) {
+  // Single-word bitset.
+  constexpr auto B64 = Bitset<64>(std::array<uint64_t, 1>{0xdeadbeefcafe1234});
+  static_assert(B64.getWord(0) == 0xdeadbeefcafe1234);
+
+  // Multi-word bitset.
+  constexpr auto B128 = Bitset<128>(
+      std::array<uint64_t, 2>{0x1111222233334444, 0xaaaabbbbccccdddd});
+  static_assert(B128.getWord(0) == 0x1111222233334444);
+  static_assert(B128.getWord(1) == 0xaaaabbbbccccdddd);
+
+  // Partial last word — high bits should be masked off.
+  constexpr auto B96 = Bitset<96>(
+      std::array<uint64_t, 2>{0xffffffffffffffff, 0xffffffffffffffff});
+  static_assert(B96.getWord(0) == 0xffffffffffffffff);
+  // Only lower 32 bits.
+  static_assert(B96.getWord(1) == 0x00000000ffffffff);
+
+  // Empty bitset.
+  static_assert(Bitset<64>().getWord(0) == 0);
+  static_assert(Bitset<128>().getWord(0) == 0);
+  static_assert(Bitset<128>().getWord(1) == 0);
+}
+
+TEST(BitsetTest, FindLastSet) {
+  // Empty bitset returns -1.
+  static_assert(Bitset<64>().findLastSet() == -1);
+  static_assert(Bitset<128>().findLastSet() == -1);
+
+  // Single bit set.
+  static_assert(Bitset<64>({0}).findLastSet() == 0);
+  static_assert(Bitset<64>({63}).findLastSet() == 63);
+  static_assert(Bitset<64>({31}).findLastSet() == 31);
+  static_assert(Bitset<128>({0}).findLastSet() == 0);
+  static_assert(Bitset<128>({64}).findLastSet() == 64);
+  static_assert(Bitset<128>({127}).findLastSet() == 127);
+
+  // Multiple bits — returns highest.
+  static_assert(Bitset<64>({0, 10, 50}).findLastSet() == 50);
+  static_assert(Bitset<128>({0, 63, 64, 100}).findLastSet() == 100);
+
+  // All bits set.
+  static_assert(Bitset<64>().set().findLastSet() == 63);
+  static_assert(Bitset<128>().set().findLastSet() == 127);
+  static_assert(Bitset<96>().set().findLastSet() == 95);
+
+  // Non-power-of-2 sizes.
+  static_assert(Bitset<33>({32}).findLastSet() == 32);
+  static_assert(Bitset<33>({0, 32}).findLastSet() == 32);
+  static_assert(Bitset<65>({64}).findLastSet() == 64);
+}
+
+TEST(BitsetTest, ShiftMultiWords) {
+  constexpr auto B192 = Bitset<192>({0, 64, 128});
+  static_assert((B192 << 1) == Bitset<192>({1, 65, 129}));
+  static_assert((B192 >> 1) == Bitset<192>({63, 127}));
+  static_assert((B192 << 64) == Bitset<192>({64, 128}));
+  static_assert((B192 >> 64) == Bitset<192>({0, 64}));
+  static_assert((Bitset<192>({63, 127}) << 1) == Bitset<192>({64, 128}));
+  static_assert((Bitset<192>({64, 128}) >> 1) == Bitset<192>({63, 127}));
+}
+
+TEST(BitsetTest, ShiftBoundaryBitShifts) {
+  static_assert((Bitset<128>({1}) << 63) == Bitset<128>({64}));
+  static_assert((Bitset<128>({64}) >> 63) == Bitset<128>({1}));
+  static_assert((Bitset<192>({1, 65}) << 63) == Bitset<192>({64, 128}));
+  // Shift by NumBits - 1.
+  static_assert((Bitset<64>({0}) << 63) == Bitset<64>({63}));
+  static_assert((Bitset<64>({63}) >> 63) == Bitset<64>({0}));
+  static_assert((Bitset<33>({0}) << 32) == Bitset<33>({32}));
+  // Full-width shift of a fully-set bitset loses exactly one bit.
+  static_assert((Bitset<128>().set() << 1).count() == 127);
+  static_assert((Bitset<128>().set() >> 1).count() == 127);
+  static_assert((Bitset<100>().set() >> 1).count() == 99);
+}
+
+TEST(BitsetTest, ShiftExcessAmount) {
+  static_assert((Bitset<64>().set() << 65).none());
+  static_assert((Bitset<64>().set() >> 200).none());
+  static_assert((Bitset<33>({0, 10, 32}) << 1000).none());
+  static_assert((Bitset<128>({0, 127}) >> 1000).none());
+  static_assert((Bitset<192>().set() << 193).none());
+}
+
+TEST(BitsetTest, ShiftAssignReturnsReference) {
+  constexpr Bitset<64> L = [] {
+    Bitset<64> X({0});
+    (X <<= 3) <<= 2;
+    return X;
+  }();
+  static_assert(L == Bitset<64>({5}));
+
+  constexpr Bitset<128> R = [] {
+    Bitset<128> X({100});
+    (X >>= 30) >>= 10;
+    return X;
+  }();
+  static_assert(R == Bitset<128>({60}));
+}
+
+TEST(BitsetTest, GetWordConsistencyWithTest) {
+  // For every set bit, getWord must report it in the expected 64-bit word.
+  constexpr auto B100 = Bitset<100>({0, 50, 64, 99});
+  static_assert((B100.getWord(0) & 1) != 0);
+  static_assert((B100.getWord(0) & (uint64_t(1) << 50)) != 0);
+  static_assert((B100.getWord(1) & 1) != 0);
+  static_assert((B100.getWord(1) & (uint64_t(1) << 35)) != 0);
+}
+
+TEST(BitsetTest, GetWordAfterMutation) {
+  // getWord reflects subsequent set / shift.
+  constexpr auto B = [] {
+    Bitset<128> X;
+    X.set(5).set(70);
+    return X;
+  }();
+  static_assert(B.getWord(0) == (uint64_t(1) << 5));
+  static_assert(B.getWord(1) == (uint64_t(1) << 6));
+
+  constexpr auto Shifted = Bitset<128>({5}) << 64;
+  static_assert(Shifted.getWord(0) == 0);
+  static_assert(Shifted.getWord(1) == (uint64_t(1) << 5));
+}
+
+TEST(BitsetTest, GetNumWordsMoreWidths) {
+  static_assert(Bitset<2>::getNumWords() == 1);
+  static_assert(Bitset<192>::getNumWords() == 3);
+  static_assert(Bitset<193>::getNumWords() == 4);
+  static_assert(Bitset<256>::getNumWords() == 4);
+}
+
+TEST(BitsetTest, FindLastSetSmallWidths) {
+  static_assert(Bitset<1>().findLastSet() == -1);
+  static_assert(Bitset<1>({0}).findLastSet() == 0);
+  static_assert(Bitset<2>({0, 1}).findLastSet() == 1);
+  static_assert(Bitset<32>({31}).findLastSet() == 31);
+  static_assert(Bitset<32>().set().findLastSet() == 31);
+}
+
+TEST(BitsetTest, FindLastSetMultiWordScan) {
+  static_assert(Bitset<192>({70}).findLastSet() == 70);
+  static_assert(Bitset<192>({64, 70, 127}).findLastSet() == 127);
+  static_assert(Bitset<192>({3}).findLastSet() == 3);
+  static_assert(Bitset<100>({99}).findLastSet() == 99);
+}
+
+TEST(BitsetTest, FindLastSetAfterMutation) {
+  constexpr auto A = Bitset<128>({0, 50, 100}).reset(100);
+  static_assert(A.findLastSet() == 50);
+
+  constexpr auto B = Bitset<64>({10}) << 20;
+  static_assert(B.findLastSet() == 30);
+
+  constexpr auto C = Bitset<64>({63}) >> 10;
+  static_assert(C.findLastSet() == 53);
+
+  constexpr auto D = Bitset<64>({63}) << 1;
+  static_assert(D.findLastSet() == -1);
+}
+
 } // namespace

@JiachenYuan JiachenYuan force-pushed the perf/jiachen/bitset_prepare_for_lbm branch from 7f36829 to 657c2db Compare April 22, 2026 20:20
@JiachenYuan
Copy link
Copy Markdown
Contributor Author

Adding @arsenm and @s-barannikov for viz. Thank you!

@github-actions
Copy link
Copy Markdown

github-actions Bot commented Apr 22, 2026

🐧 Linux x64 Test Results

  • 194303 tests passed
  • 5132 tests skipped

✅ The build succeeded and all tests passed.

@github-actions
Copy link
Copy Markdown

github-actions Bot commented Apr 22, 2026

🪟 Windows x64 Test Results

  • 133903 tests passed
  • 3153 tests skipped

✅ The build succeeded and all tests passed.

@JiachenYuan
Copy link
Copy Markdown
Contributor Author

Failure related to this: #193558. Rebasing and testing again.

@JiachenYuan JiachenYuan force-pushed the perf/jiachen/bitset_prepare_for_lbm branch from 657c2db to 7cbc4b4 Compare April 22, 2026 21:53
Comment thread llvm/include/llvm/ADT/Bitset.h Outdated
}

/// Return the number of 64-bit words needed to hold all bits.
static constexpr unsigned getNumWords() { return (NumBits + 63) / 64; }
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just return NumWords?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this question is related to the one below, and I tried to reply to both here.

Comment thread llvm/include/llvm/ADT/Bitset.h Outdated
}

/// Return the I-th 64-bit word of the bitset, from least significant to most.
constexpr uint64_t getWord(unsigned I) const {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like the naming is confusing because getWord would suggest that we just get the I-th word from the array, which would not require any operation other than Bits[I] and wouldn't need bitsize inspection. I would expect the return value of a getWord routine to be BitWord

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BitwordBits is architecture-dependent -- it could be 32-bit or 64-bit. Instead of making both getWord and getNumWords thin getters, my intention was that we provide a normalized 64-bit view to the external accessors. This way, we are not exposing every implementation detail to the consumers of Bitset. On the other hand, I think the names are indeed a little bit confusing. Would it make more sense if I rename them to be getWord64() and getNumWords64()?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, I guess that would make more sense

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I have changed the function names accordingly.

/// Return the I-th 64-bit word of the bitset, from least significant to most.
constexpr uint64_t getWord(unsigned I) const {
if constexpr (BitwordBits == 64) {
return Bits[I];
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could also do some check or workaround for index-out-of-bounds, like you did for the 32bit case

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, thanks for catching this! I added an assertion to check index-out-of-bounds.

@JiachenYuan JiachenYuan force-pushed the perf/jiachen/bitset_prepare_for_lbm branch from 7cbc4b4 to 27bc031 Compare April 24, 2026 17:18
@JiachenYuan JiachenYuan force-pushed the perf/jiachen/bitset_prepare_for_lbm branch from 27bc031 to e9a137f Compare April 24, 2026 17:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants